{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# HardSigmoid\n",
    "\n",
    "参考：[HardSigmoid](https://onnx.ai/onnx/operators/onnx__HardSigmoid.html)\n",
    "\n",
    "HardSigmoid 函数定义为：\n",
    "\n",
    "$$\n",
    "\\operatorname{HardSigmoid}(x) = \\max(0, \\min(1, \\alpha x + \\beta))\n",
    "$$\n",
    "\n",
    "其中 $\\alpha$ 和 $\\beta$ 的默认值分别为 $0.2$ 和 $0.5$。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "# Scratch directory for generated files; mkdir(exist_ok=True) makes re-running this cell safe.\n",
    "temp_dir = Path(\".temp\")\n",
    "temp_dir.mkdir(exist_ok=True)\n",
    "model_path = f\"{temp_dir}/HardSigmoid.onnx\" # path (a plain str) where the ONNX model is stored"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "构建模型："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "model: Already defined.\n"
     ]
    }
   ],
   "source": [
    "from onnxscript import opset20 as op\n",
    "from onnxscript import ir\n",
    "from onnxscript import script\n",
    "from onnxscript import FLOAT\n",
    "import onnx\n",
    "\n",
    "# NOTE(review): every HardSigmoid / HardSwish variant inside the function is commented out,\n",
    "# so the exported graph contains only Add(x, x) and no HardSigmoid node.\n",
    "@script()\n",
    "def model(x: FLOAT[1, 3, 4, 4]) -> FLOAT[1, 3, 4, 4]:\n",
    "    x = op.Add(x, x)\n",
    "    # Disabled variants, kept for reference:\n",
    "    # y = op.HardSigmoid(x)\n",
    "    # x = op.Mul(x, y)\n",
    "    # x = op.HardSigmoid(x)\n",
    "    return x\n",
    "    # return op.HardSwish(x,)\n",
    "\n",
    "# Convert the scripted function to an ONNX ModelProto and write it to model_path.\n",
    "onnx.save_model(model.to_model_proto(), model_path,)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 转换为 Relay 模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #AA22FF\">@main</span>(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x: Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">4</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">4</span>), float32] span<span style=\"color: #AA22FF; font-weight: bold\">=</span>n0<span style=\"color: #AA22FF; font-weight: bold\">.</span>x:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">*/</span>) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">4</span>), float32] {\n",
       "  add(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x, <span style=\"color: #AA22FF; font-weight: bold\">%</span>x) <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">4</span>), float32] span<span style=\"color: #AA22FF; font-weight: bold\">=</span>n0:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">*/</span>\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from tvm.driver.tvmc.frontends import load_model\n",
    "\n",
    "# NOTE(review): this rebinds `model`. Above it named the onnxscript function; from here on it is\n",
    "# the object returned by load_model, whose .mod and .params are read by the cells below.\n",
    "model = load_model(model_path)\n",
    "model.mod.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 构建模板"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tvm.relay.dataflow_pattern import (\n",
    "    TupleGetItemPattern, is_op, wildcard,\n",
    "    is_constant, rewrite,\n",
    "    DFPatternCallback\n",
    ")\n",
    "import tvm\n",
    "from tvm.relay import op as _op\n",
    "from tvm.relay import transform as _transform\n",
    "from tvm import relay\n",
    "\n",
    "def make_hard_sigmoid_pattern():\n",
    "    r\"\"\"Build the dataflow pattern ``clip(add(multiply(x, alpha), beta))``.\n",
    "\n",
    "    ``x`` is a wildcard, while ``alpha`` and ``beta`` are constant patterns.\n",
    "    No constraint is placed on the attributes of ``clip``, so a clip with any\n",
    "    bounds satisfies the pattern.\n",
    "\n",
    "    Returns:\n",
    "        The pattern rooted at the ``clip`` call.\n",
    "    \"\"\"\n",
    "    scaled = is_op(\"multiply\")(wildcard(), is_constant())\n",
    "    shifted = is_op(\"add\")(scaled, is_constant())\n",
    "    return is_op(\"clip\")(shifted)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pattern-table entries are (name, pattern) pairs; the name is prefixed with the compiler tag.\n",
    "compiler = \"ccompiler\"\n",
    "pattern_table = [\n",
    "    (f\"{compiler}.hard_sigmoid\", make_hard_sigmoid_pattern()),\n",
    "]\n",
    "\n",
    "# Only MergeComposite is applied. Passes left disabled in this notebook:\n",
    "# AnnotateTarget(compiler), MergeCompilerRegions, PartitionGraph, InferType, Inline, DefuseOps.\n",
    "seq = tvm.transform.Sequential([_transform.MergeComposite(pattern_table)])\n",
    "mod = seq(model.mod)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #AA22FF\">@main</span>(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x: Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">4</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">4</span>), float32] span<span style=\"color: #AA22FF; font-weight: bold\">=</span>n0<span style=\"color: #AA22FF; font-weight: bold\">.</span>x:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">*/</span>) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">4</span>), float32] {\n",
       "  add(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x, <span style=\"color: #AA22FF; font-weight: bold\">%</span>x) <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">4</span>), float32] span<span style=\"color: #AA22FF; font-weight: bold\">=</span>n0:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">*/</span>\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Print the module produced by running the MergeComposite sequence on model.mod.\n",
    "mod.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def get_calibration_dataset(mod, input_name, num_samples=5, seed=None):\n",
    "    \"\"\"Generate random calibration batches for the first input of ``mod[\"main\"]``.\n",
    "\n",
    "    Each batch is a dict mapping ``input_name`` to an array sampled uniformly\n",
    "    from the unit interval, using the shape and dtype declared for the first\n",
    "    argument of the main function.\n",
    "\n",
    "    Args:\n",
    "        mod: Type-checked module; ``mod[\"main\"].checked_type.arg_types[0]``\n",
    "            must expose ``shape`` and ``dtype``.\n",
    "        input_name: Key under which the generated array is stored in each batch.\n",
    "        num_samples: Number of batches to generate (default 5).\n",
    "        seed: Optional seed for reproducible data; ``None`` gives fresh random data.\n",
    "\n",
    "    Returns:\n",
    "        A list of ``num_samples`` single-entry dicts.\n",
    "    \"\"\"\n",
    "    input_type = mod[\"main\"].checked_type.arg_types[0]\n",
    "    input_shape = [int(dim) for dim in input_type.shape]\n",
    "    rng = np.random.default_rng(seed)\n",
    "    # NumPy samples float64 by default; cast so the data matches the declared input dtype.\n",
    "    return [\n",
    "        {input_name: rng.uniform(size=input_shape).astype(input_type.dtype)}\n",
    "        for _ in range(num_samples)\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #AA22FF\">@main</span>(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x: Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">4</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">4</span>), float32] span<span style=\"color: #AA22FF; font-weight: bold\">=</span>n0<span style=\"color: #AA22FF; font-weight: bold\">.</span>x:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">*/</span>) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">4</span>), float32] {\n",
       "  add(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x, <span style=\"color: #AA22FF; font-weight: bold\">%</span>x) <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">4</span>), float32] span<span style=\"color: #AA22FF; font-weight: bold\">=</span>n0:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">*/</span>\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# NOTE(review): the calibration data is shaped from `mod`, while the module that is\n",
    "# quantized below is `model.mod`.\n",
    "dataset = get_calibration_dataset(mod, \"x\")\n",
    "\n",
    "# qconfig overrides that are not passed here (kept for reference):\n",
    "# calibrate_mode=\"kl_divergence\", rounding=\"TONEAREST\" (or \"UPWARD\"), calibrate_skip_layers=[].\n",
    "with tvm.transform.PassContext(opt_level=3), relay.quantize.qconfig(\n",
    "    skip_conv_layers=[],\n",
    "    weight_scale=\"max\",\n",
    "    round_for_shift=True,\n",
    "    skip_dense_layer=False,\n",
    "):\n",
    "    qmod = relay.quantize.quantize(model.mod, model.params, dataset)\n",
    "qmod.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
